# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for open_spiel.python.algorithms.exploitability."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest
from absl.testing import parameterized

from open_spiel.python import policy
from open_spiel.python.algorithms import exploitability
from open_spiel.python.algorithms import policy_utils
from open_spiel.python.games import data
import pyspiel


class ExploitabilityTest(parameterized.TestCase):
  """Tests NashConv, exploitability and best responses on small poker games."""

  # Reference NashConv values, shared by the individual tests and the
  # parameterized test below so the two copies cannot drift apart.
  #
  # NashConv of the uniform random policy in two-player Kuhn poker, from
  # (found on Google Books):
  # https://link.springer.com/chapter/10.1007/978-3-319-75931-9_5
  _KUHN_UNIFORM_NASH_CONV = 11 / 12
  # NashConv of the uniform random policy in two-player Leduc poker. The value
  # reported in the paper above is wrong; this value has been verified in
  # several independent code bases.
  _LEDUC_UNIFORM_NASH_CONV = 4.747222222222222
  # NashConv of a policy that always takes the first action (fold) in
  # two-player Kuhn or Leduc poker. If you always fold you win 0 chips, but if
  # you switch to always betting you win 1 chip every time against a player
  # who always folds; summed over both players this gives 2.
  _ALWAYS_FOLD_NASH_CONV = 2.

  def test_exploitability_on_kuhn_poker_uniform_random(self):
    # Exploitability is checked against half of the two-player NashConv.
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    expected_nash_conv = self._KUHN_UNIFORM_NASH_CONV
    self.assertAlmostEqual(
        exploitability.exploitability(game, test_policy),
        expected_nash_conv / 2)

  def test_kuhn_poker_uniform_random_best_response_pid0(self):
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    results = exploitability.best_response(game, test_policy, player_id=0)
    self.assertEqual(
        results["best_response_action"],
        {
            "0": 1,  # Bet in case opponent folds when winning
            "1": 1,  # Bet in case opponent folds when winning
            "2": 0,  # Both equally good (we return the lowest action)
            # Some of these will never happen under the best-response policy,
            # but we have computed best-response actions anyway.
            "0pb": 0,  # Fold - we're losing
            "1pb": 1,  # Call - we're 50-50
            "2pb": 1,  # Call - we've won
        })
    self.assertGreater(results["nash_conv"], 0.1)

  def test_kuhn_poker_uniform_random_best_response_pid1(self):
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    results = exploitability.best_response(game, test_policy, player_id=1)
    self.assertEqual(
        results["best_response_action"],
        {
            # Bet is always best
            "0p": 1,
            "1p": 1,
            "2p": 1,
            # Call unless we know we're beaten
            "0b": 0,
            "1b": 1,
            "2b": 1,
        })
    self.assertGreater(results["nash_conv"], 0.1)

  def test_kuhn_poker_uniform_random(self):
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    self.assertAlmostEqual(
        exploitability.nash_conv(game, test_policy),
        self._KUHN_UNIFORM_NASH_CONV)

  def test_kuhn_poker_always_fold(self):
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.FirstActionPolicy(game)
    self.assertAlmostEqual(
        exploitability.nash_conv(game, test_policy),
        self._ALWAYS_FOLD_NASH_CONV)

  def test_kuhn_poker_optimal(self):
    game = pyspiel.load_game("kuhn_poker")
    test_policy = data.kuhn_nash_equilibrium(alpha=0.2)
    self.assertAlmostEqual(exploitability.nash_conv(game, test_policy), 0)

  def test_leduc_poker_uniform_random(self):
    game = pyspiel.load_game("leduc_poker")
    test_policy = policy.UniformRandomPolicy(game)
    self.assertAlmostEqual(
        exploitability.nash_conv(game, test_policy),
        self._LEDUC_UNIFORM_NASH_CONV)

  def test_leduc_poker_always_fold(self):
    game = pyspiel.load_game("leduc_poker")
    test_policy = policy.FirstActionPolicy(game)
    self.assertAlmostEqual(
        exploitability.nash_conv(game, test_policy),
        self._ALWAYS_FOLD_NASH_CONV)

  # Two-player NashConv checks against the reference values defined at the top
  # of the class. The multiplayer tests further below only check a lower bound.
  # The class-level constants are referenced by bare name here because the
  # decorator arguments are evaluated inside the class body.
  @parameterized.parameters(
      ("kuhn_poker", policy.UniformRandomPolicy, _KUHN_UNIFORM_NASH_CONV),
      ("kuhn_poker", policy.FirstActionPolicy, _ALWAYS_FOLD_NASH_CONV),
      ("kuhn_poker", lambda _: data.kuhn_nash_equilibrium(alpha=0.2), 0.),
      ("leduc_poker", policy.FirstActionPolicy, _ALWAYS_FOLD_NASH_CONV),
      ("leduc_poker", policy.UniformRandomPolicy, _LEDUC_UNIFORM_NASH_CONV),
  )
  def test_2p_nash_conv(self, game_name, policy_func, expected):
    game = pyspiel.load_game(game_name)
    self.assertAlmostEqual(
        exploitability.nash_conv(game, policy_func(game)), expected)

  @parameterized.parameters(3, 4)
  def test_kuhn_poker_uniform_random_nash_conv(self, num_players):
    game = pyspiel.load_game("kuhn_poker",
                             {"players": pyspiel.GameParameter(num_players)})
    test_policy = policy.UniformRandomPolicy(game)
    self.assertGreater(exploitability.nash_conv(game, test_policy), 2)

  @parameterized.parameters(("kuhn_poker", 2), ("kuhn_poker", 3),
                            ("kuhn_poker", 4))
  def test_python_same_as_cpp_for_multiplayer_uniform_random_nash_conv(
      self, game_name, num_players):
    game = pyspiel.load_game(game_name,
                             {"players": pyspiel.GameParameter(num_players)})

    # TabularPolicy defaults to being a uniform random policy.
    test_policy = policy.TabularPolicy(game)
    python_nash_conv = exploitability.nash_conv(game, test_policy)
    cpp_nash_conv = pyspiel.nash_conv(
        game, policy_utils.policy_to_dict(test_policy, game))
    self.assertAlmostEqual(python_nash_conv, cpp_nash_conv)


if __name__ == "__main__":
  # Run the test cases defined in this module through absl's test runner.
  absltest.main()
